import importlib
import os
import shutil

import pytorch_lightning as pl
import wandb
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

# from trainers.trainer_simple import LitModel


def _experiment_name(conf_path):
    """Derive a run name '<parent_dir>.<file_stem>' from a config file path.

    e.g. ``configs/nerf/base.yaml`` -> ``nerf.base``. Only the final file
    extension is removed, so a path without an extension keeps its full name.
    """
    stem, _ext = os.path.splitext(conf_path)
    # Keep only the last two path components (sub-folder + file stem).
    return '.'.join(stem.split('/')[-2:])


def _wandb_mode(conf):
    """Return the wandb run mode implied by the ``debug``/``offline`` flags.

    :return: ``'offline'`` if ``offline`` is set (it wins over ``debug``),
        ``'disabled'`` if only ``debug`` is set, otherwise ``None`` (online).
    """
    if conf.get('offline', False):
        return 'offline'
    if conf.get('debug', False):
        return 'disabled'
    return None


def run_conf(conf):
    """Run the training loop for the model described by ``conf``.

    Keys read from ``conf``: ``wandb.project``, ``wandb.log_model``,
    ``wandb.log_freq`` (unless ``debug``), ``train.num_epochs``, ``gpus`` and
    ``conf_path`` (only when ``exp_name`` is not given). Optional keys:
    ``debug``, ``offline``, ``exp_name``, ``trainer_name`` (module under
    ``trainers/``, default ``trainer_simple``), ``ckpt`` (wandb artifact name to
    resume from) and ``train.trainer_args`` (extra ``pl.Trainer`` kwargs).

    :param conf: the OmegaConf object that contains all the parameters
    """
    debug = conf.get('debug', False)
    offline = conf.get('offline', False)

    # Model artifacts are only uploaded for real (online, non-debug) runs.
    # The configured value is passed through untouched (it may be e.g. 'all').
    log_model = False if (debug or offline) else conf.wandb.log_model

    exp_name = conf.get('exp_name', None)
    if exp_name is None:
        exp_name = _experiment_name(conf.conf_path)

    mode = _wandb_mode(conf)
    logger_kwargs = {} if mode is None else {'mode': mode}
    os.makedirs('wandb_log', exist_ok=True)
    wandb_logger = WandbLogger(
        name=exp_name, project=conf.wandb.project,
        log_model=log_model, save_dir='wandb_log', **logger_kwargs
    )
    wandb_logger.log_hyperparams(conf)
    print(OmegaConf.to_yaml(conf))

    trainer_name = conf.get('trainer_name', 'trainer_simple')
    lit_model_cls = importlib.import_module(f'trainers.{trainer_name}').LitModel
    model = lit_model_cls(conf)

    ckpt = conf.get('ckpt', '')
    if ckpt:  # also skips an explicit `ckpt: null`
        model_at = wandb_logger.experiment.use_artifact(ckpt)
        model_dir = model_at.download()
        # load_from_checkpoint is a classmethod; call it on the class (newer
        # Lightning versions reject calls on an instance).
        # NOTE(review): assumes LitModel stores its constructor args in the
        # checkpoint (e.g. save_hyperparameters); otherwise pass conf here.
        model = lit_model_cls.load_from_checkpoint(
            os.path.join(model_dir, 'model.ckpt'))
        print('loaded!', flush=True)

    if not debug:
        wandb_logger.watch(model.nerf, log='all', log_freq=conf.wandb.log_freq)

    callbacks = [LearningRateMonitor()]
    if offline:
        ckpt_dir = f'checkpoints/{exp_name}'
        # Start from an empty directory so stale top-k checkpoints from an
        # earlier run with the same name are not mixed in. Done in-process
        # (not via a shell) so names with spaces/metacharacters are safe;
        # best-effort like `rm -r`, a missing directory is not an error.
        shutil.rmtree(ckpt_dir, ignore_errors=True)
        callbacks.append(ModelCheckpoint(
            dirpath=ckpt_dir,
            filename='epoch{epoch}-val_psnr{val/psnr:.2f}',
            save_top_k=3,
            verbose=True,
            monitor='val/psnr',
            mode='max',
            auto_insert_metric_name=False
        ))

    trainer = pl.Trainer(
        logger=wandb_logger, max_epochs=conf.train.num_epochs,
        gpus=conf.gpus, **conf.train.get('trainer_args', {}), callbacks=callbacks
    )
    trainer.fit(model)
    wandb.finish()

if __name__ == '__main__':
    # key=value overrides from the command line, e.g.
    #   conf_path=configs/exp.yaml train.num_epochs=10
    cli_conf = OmegaConf.from_cli()
    if cli_conf.get('conf_path') is None:
        raise SystemExit(
            'usage: python <script> conf_path=<path/to/config.yaml> [key=value ...]')
    exp_conf = OmegaConf.load(cli_conf.conf_path)
    # The data config may be chosen in the experiment YAML or on the CLI.
    data_conf_path = OmegaConf.merge(exp_conf, cli_conf).get(
        'data_conf_path', 'data_configs/cow.yaml')
    data_conf = OmegaConf.load(data_conf_path)
    # Precedence (lowest -> highest): experiment YAML, data YAML, CLI.
    # Merging the CLI last keeps command-line overrides from being silently
    # clobbered by keys that also appear in the data config.
    conf = OmegaConf.merge(exp_conf, data_conf, cli_conf)
    run_conf(conf)